import unittest

import torch

SKIP_TEST = None
try:
    from apex.contrib.transducer import TransducerLoss
    from apex.contrib.transducer import _transducer_ref as transducer_ref
except ImportError as e:
    SKIP_TEST = e


@unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}")
class TransducerLossTest(unittest.TestCase):
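    """Compare TransducerLoss against the pure-PyTorch reference implementation.

    Covers fp32/fp16, fused and unfused softmax backward, padded and packed
    input, and the vectorized-kernel path.
    """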
    def setUp(self, seed=1234):
        torch.manual_seed(seed)

    def gen_input(self, scalar_t, for_vector_kernel):
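        """Generate random transducer inputs and the matching reference outputs.

        Builds padded joint-network logits ``x_tst`` of shape (B, T_max, U_max, V)
        plus random labels and lengths, packs a copy of the input, and computes
        the reference loss and gradients with the pure-PyTorch implementation.
        """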
        self.B = 5
        T_min = 23
        T_max = 51
        U_min = 12
        U_max = 25
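        # Choose a vocab size compatible with the vectorized kernel (16), or one
        # that forces the non-vectorized code path (14).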
        V = 16 if for_vector_kernel else 14
        self.blank_idx = V - 1
        device = "cuda"

        self.x_tst = torch.randn(
            (self.B, T_max, U_max, V), dtype=scalar_t, requires_grad=True, device=device
        )
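        # Labels are drawn from the non-blank symbols only.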
        self.y = torch.randint(
            0, self.blank_idx, (self.B, U_max - 1), dtype=torch.int, device=device
        )
        self.f_len = torch.randint(T_min, T_max + 1, (self.B,), dtype=torch.int, device=device)
        self.y_len = torch.randint(U_min - 1, U_max, (self.B,), dtype=torch.int, device=device)
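        # Ensure at least one sample reaches the maximum lengths, so the valid
        # regions extend to the padded dimensions T_max and U_max.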
        self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max
        self.y_len[torch.randint(0, self.B, (1,)).item()] = U_max - 1
        self.x_tst_packed, self.batch_offset = self._pack(self.x_tst)
        # Generate the reference loss and gradients with the pure-PyTorch
        # implementation, using an independent copy of the input.
        x_ref = self.x_tst.data.clone()
        x_ref.requires_grad = True
        batch_size = x_ref.size(0)
        loss_grad = torch.ones(batch_size, dtype=x_ref.dtype, device=x_ref.device) / batch_size
        _, _, self.grad_ref, self.loss_ref = transducer_ref.transducer_loss_reference(
            x=x_ref,
            label=self.y,
            f_len=self.f_len,
            y_len=self.y_len,
            blank_idx=self.blank_idx,
            loss_grad=loss_grad,
        )

    def _pack(self, x):
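        """Pack the padded (B, T, U, V) tensor into a flat tensor.

        Only the valid (f_len[b], y_len[b] + 1) region of each sample is kept,
        concatenated in row-major order. Returns the packed tensor and the
        cumulative per-sample offsets the packed-input path expects.
        """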
        list_x = []
        for b in range(self.B):
            list_x_row = [x[b, t, : self.y_len[b] + 1] for t in range(self.f_len[b])]
            x_row = torch.cat(list_x_row)
            list_x.append(x_row)
        x_packed = torch.cat(list_x).data.clone()
        x_packed.requires_grad = True
        batch_offset = torch.cumsum(self.f_len * (self.y_len + 1), dim=0)
        return x_packed, batch_offset

    def _unpack(self, x):
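        """Invert ``_pack``: scatter a packed tensor back to the padded layout.

        Padding positions are zero-filled so the result can be compared against
        the gradient of the padded reference input.
        """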
        x_unpacked = torch.zeros(
            self.B,
            self.f_len.max(),
            self.y_len.max() + 1,
            x.size(-1),
            dtype=x.dtype,
            device=x.device,
        )
        for b in range(self.B):
            my_batch_offset = 0 if b == 0 else self.batch_offset[b - 1]
            my_f_len = self.f_len[b]
            my_g_len = self.y_len[b] + 1
            for t in range(my_f_len):
                for u in range(my_g_len):
                    x_unpacked[b, t, u] = x[my_batch_offset + t * my_g_len + u]
        return x_unpacked

    def run_transducer_loss(self, scalar_t, fuse_softmax_backward, packed_input, for_vector_kernel):
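        """Run TransducerLoss under the given configuration.

        Returns the per-batch loss and the input gradient; for packed input,
        the gradient is unpacked back to the padded layout before returning.
        """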
        self.gen_input(scalar_t, for_vector_kernel)
        my_loss = TransducerLoss(
            fuse_softmax_backward=fuse_softmax_backward, packed_input=packed_input
        )
        if not packed_input:
            loss_tst = my_loss(
                x=self.x_tst,
                label=self.y,
                f_len=self.f_len,
                y_len=self.y_len,
                blank_idx=self.blank_idx,
            )
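            # mean() applies the same 1/B weighting used for the reference
            # gradients (loss_grad in gen_input).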
            loss_tst.mean().backward()
            grad_tst = self.x_tst.grad
        else:
            loss_tst = my_loss(
                x=self.x_tst_packed,
                label=self.y,
                f_len=self.f_len,
                y_len=self.y_len,
                blank_idx=self.blank_idx,
                batch_offset=self.batch_offset,
                max_f_len=self.f_len.max(),
            )
            loss_tst.mean().backward()
            grad_tst_packed = self.x_tst_packed.grad
            grad_tst = self._unpack(grad_tst_packed)

        return loss_tst, grad_tst

    def test_transducer_loss_fp32(self):
        loss_tst, grad_tst = self.run_transducer_loss(
            scalar_t=torch.float32,
            fuse_softmax_backward=False,
            packed_input=False,
            for_vector_kernel=False,
        )
        torch.testing.assert_close(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)
        torch.testing.assert_close(self.grad_ref, grad_tst, atol=1e-5, rtol=1e-5)

    def test_transducer_loss_fp16(self):
        loss_tst, grad_tst = self.run_transducer_loss(
            scalar_t=torch.float16,
            fuse_softmax_backward=False,
            packed_input=False,
            for_vector_kernel=False,
        )
        torch.testing.assert_close(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)
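        # fp16 gradients accumulate more rounding error, hence the looser tolerances.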
        torch.testing.assert_close(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)

    def test_transducer_loss_fp16_backward_fusion(self):
        loss_tst, grad_tst = self.run_transducer_loss(
            scalar_t=torch.float16,
            fuse_softmax_backward=True,
            packed_input=False,
            for_vector_kernel=False,
        )
        torch.testing.assert_close(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)
        torch.testing.assert_close(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)

    def test_transducer_loss_fp16_backward_fusion_packed(self):
        loss_tst, grad_tst = self.run_transducer_loss(
            scalar_t=torch.float16,
            fuse_softmax_backward=True,
            packed_input=True,
            for_vector_kernel=False,
        )
        torch.testing.assert_close(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)
        torch.testing.assert_close(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)

    def test_transducer_loss_fp16_backward_fusion_packed_vec(self):
        loss_tst, grad_tst = self.run_transducer_loss(
            scalar_t=torch.float16,
            fuse_softmax_backward=True,
            packed_input=True,
            for_vector_kernel=True,
        )
        torch.testing.assert_close(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)
        torch.testing.assert_close(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)


if __name__ == "__main__":
    unittest.main()
